import numpy
import random
from datetime import datetime,timedelta
from backtest_optimize.Yhlz_Module import *


class geneticAlgorithm():
    """
    遗传算法优化，接受list的被优化的参数，以及它们的评估函数的评估结果，得出新的参数
    """

    def __init__(self, paraList, num=100, learningRate=0.1):
        """

        :param paraList 现有群体的基因参数
        :param num: 种群数量
        :param changeRate: 变异概率
        """
        self.num = num
        self.changeRate = learningRate
        # 需要是矩阵不能是list
        assert isinstance(paraList[0], numpy.ndarray)
        self.paraList = paraList
        self.max = 0.0
        self.min = float('inf')
        for item in self.paraList:
            if self.min > item.min():
                self.min = item.min()
            if self.max < item.max():
                self.max = item.max()

    def _mutation(self):
        """变异操作"""
        changeNum = int(self.paraList[0].size * self.changeRate)
        for item in self.paraList:
            row = numpy.random.randint(0, item.shape[0], changeNum)
            col = numpy.random.randint(0, item.shape[1], changeNum)
            value = numpy.random.rand(changeNum) * (self.max - self.min) + self.min
            item[row, col] = value

    def _cross(self, eva_res):
        """
        交叉操作
        :param eva_res: 评估结果，也就是和para list对应的顺序的个体的适应度
        :return:
        """
        tempPara = []
        weight = 1 / (numpy.array(eva_res)) + 1  # 将他们用反比例函数变为权重
        for i in range(self.num):
            res = random.choices(tempPara, weight, k=2)
            temp = numpy.zeros(res[0].shape)
            pos = random.randint(0, res[0].size)
            temp.ravel()[:pos] = res[0].ravel()[:pos]  # 前面用0，后面用1
            temp.ravel()[pos:] = res[0].ravel()[pos:]
            tempPara.append(temp)
            # TODO 还需要reshape回本来的样子。
        self.paraList = tempPara

    def optimize(self, evaRes):
        """用遗传算法的规则以及评估结果优化这个参数"""
        self._cross(evaRes)
        self._mutation()
        return self.paraList


class geneticAlgorithmInBacktest(geneticAlgorithm):
    """在整个时间周期内用遗传算法进行参数优化来让交易策略的性能变得更好。"""

    def __init__(self, paraList, num=100, learningRate=0.1, start=None, end=None, splitPeriod=30):
        """

        :param paraList:
        :param num: 种群数量
        :param learningRate: 学习/变异的概率。
        :param barCount: 多少根bar触发一次优化。
        :param splitPeriod: 进行遗传计算的时间周期，单位是天
        """

        super(geneticAlgorithmInBacktest, self).__init__(paraList, num, learningRate)
        assert start != None
        assert end != None

        if isinstance(start, str) and isinstance(end, str):
            start = datetime.strptime(start, '%Y-%m-%d %H:%M:%S')
            end = datetime.strptime(end, '%Y-%m-%d %H:%M:%S')
        elif isinstance(start, datetime) and isinstance(end, datetime):
            pass
        else:
            raise Exception('未知的起始终止时间种类！')

        self.splitPeroid = splitPeriod
        tempEnd = start
        self.feedsList = []

        while tempEnd < end:
            if tempEnd + timedelta(days=self.splitPeroid) < end:
                self.feedsList.append(DATA.feed(start=tempEnd, end=tempEnd + timedelta(days=self.splitPeroid))[1])  # 返回两个值，第二个是feed
            elif tempEnd + timedelta(days=self.splitPeroid) >= end:
                self.feedsList.append(DATA.feed(start=tempEnd, end=end)[1])
            tempEnd += timedelta(days=self.splitPeroid)

    def getFeeds(self):
        return self.feedsList

    @classmethod
    def transPara(cls, arry1: np.ndarray, arry2: np.ndarray):
        """将若两个矩阵转换为一个长条形的以便进行各种遗传操作"""
        return np.append(arry1.ravel(), arry2.ravel())

    @classmethod
    def transParaBack(cls, array: np.ndarray, pos=None):
        """将一个矩阵重新拆分成为两个，以便于神经网络进行计算, 如果不设置pos参数，表示从中间切开, 如果是奇数长度，向前取整"""
        if pos:
            temp = array.ravel()
            return temp[:int(temp.size/2)], temp[int(temp.size/2):]
        else:  # 指定了位置的情况
            temp = array.ravel()
            return temp[:int(temp.size / 2)], temp[int(temp.size / 2):]
